import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
import numpy as np
import random
from torch import nn
import torch.nn.functional  as F
import argparse


class CNN(nn.Module):
    """Two-patch network with shared filters and a fixed +1/-1 second layer.

    The same weight matrix is applied to both input patches ``x1`` and ``x2``.
    Each hidden unit uses a squared activation. The ``Wp`` branch contributes
    positively to the output and the ``Wn`` branch negatively.

    Args:
        m: number of hidden neurons per branch (columns of each weight matrix).
        d: input dimension of each patch (rows of each weight matrix).
    """

    def __init__(self, m=50, d=1000):
        super().__init__()
        # Allocate with randn first, then re-initialise in place with a small
        # std. This draw order is kept so seeded runs stay reproducible.
        self.Wp = nn.Parameter(torch.randn(d, m))
        self.Wn = nn.Parameter(torch.randn(d, m))
        for weight in (self.Wp, self.Wn):
            nn.init.normal_(weight, std=0.001)

    def act(self, input):
        """Squared activation, applied elementwise."""
        return input.pow(2)

    def forward(self, x1, x2, verbose=False):
        """Return F_p(x) - F_n(x), one score per row of ``x1``/``x2``.

        Each branch averages the squared responses over its m neurons for each
        patch, then sums the two patches. ``verbose`` is accepted but unused.
        """
        def branch_score(weight):
            signal_part = torch.mean(self.act(torch.mm(x1, weight)), 1)
            noise_part = torch.mean(self.act(torch.mm(x2, weight)), 1)
            return signal_part + noise_part

        return branch_score(self.Wp) - branch_score(self.Wn)


def prepare_data(n_train=None, n_test=None, d=None, mu=None):
    """Build the synthetic two-patch binary classification dataset.

    Each example has a signal patch ``x1`` and a noise patch ``x2``:

    * Label +1 examples have ``x1 = feature1``, which is ``mu`` at index 0.
    * Label -1 examples have ``x1 = feature2``, which is ``mu`` at index 1.
    * ``x2`` is drawn from ``torch.randn``.

    Every argument left as ``None`` falls back to the module-level global of
    the same name (set by the ``__main__`` block). The original zero-argument
    call therefore keeps working.

    Args:
        n_train: number of training examples.
        n_test: number of test examples.
        d: patch dimension (must be >= 2, since index 1 is written).
        mu: signal strength.

    Returns:
        A tuple ``(train_x1, train_x2, train_y, test_x1, test_x2, test_y,
        feature1, feature2)`` where:

        * the x tensors have shape ``(n, d)``;
        * the labels have shape ``(n,)``, with the +1 examples first;
        * the two feature vectors have shape ``(d,)``.
    """
    module_globals = globals()
    n_train = module_globals["n_train"] if n_train is None else n_train
    n_test = module_globals["n_test"] if n_test is None else n_test
    d = module_globals["d"] if d is None else d
    mu = module_globals["mu"] if mu is None else mu

    def _labels(n):
        # Positives first, then negatives. Using n - n // 2 positives keeps the
        # label count equal to n; int(n / 2) twice dropped a sample for odd n,
        # which mismatched the n rows of the noise patch.
        n_neg = n // 2
        return torch.cat((torch.ones(n - n_neg), -torch.ones(n_neg)))

    train_y = _labels(n_train)
    test_y = _labels(n_test)

    feature1 = torch.zeros(d)
    feature1[0] = mu
    feature2 = torch.zeros(d)
    feature2[1] = mu

    def _signal_patch(y):
        # Row i is feature1 when y[i] == +1 and feature2 when y[i] == -1.
        return torch.where(y.unsqueeze(1) > 0, feature1, feature2)

    train_x1 = _signal_patch(train_y)
    test_x1 = _signal_patch(test_y)
    # Draw train noise before test noise so seeded runs reproduce exactly.
    train_x2 = torch.randn(n_train, d)
    test_x2 = torch.randn(n_test, d)

    return train_x1, train_x2, train_y, test_x1, test_x2, test_y, feature1, feature2

def set_seed(seed: int):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random``, NumPy's global RNG and PyTorch's CPU RNG. When
    CUDA is available it also:

    * seeds the current device and all CUDA devices;
    * forces cuDNN into deterministic, non-benchmarking mode.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # covers every visible GPU
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def parse_args(args):
    """Parse the script's command-line options.

    Args:
        args: list of argument strings, or ``None`` to let argparse read the
            process command line.

    Returns:
        An ``argparse.Namespace`` with ``mu``, the signal strength.
        NOTE(review): the default is the int ``15``, whereas a value passed
        on the command line is converted to float. The checkpoint filename
        therefore reads ``15`` by default but ``15.0`` for ``--mu 15``.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--mu", type=float, default=15)
    return cli.parse_args(args)

if __name__ == "__main__":

    args = parse_args(None)
    seed = 5996
    n_train = 30
    n_test = 2000
    d = 1000
    n_epoch = 500
    m = 20
    mu = args.mu
    # Larger than n_train, so every epoch is a single full-batch GD step.
    batch_size = 250

    set_seed(seed)

    # prepare_data() reads n_train / n_test / d / mu from the globals above.
    train_x1, train_x2, train_y, test_x1, test_x2, test_y, feature1, feature2 = prepare_data()

    model = CNN(m=m, d=d)
    data_loader = DataLoader(
        TensorDataset(train_x1, train_x2, train_y),
        batch_size=batch_size,
        shuffle=True,
    )
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    def logistic_loss(f_pred, y):
        """Mean logistic loss log(1 + exp(-y * f)), computed stably.

        F.softplus avoids the float overflow of exp() for large negative
        margins. With the previous torch.log(1 + torch.exp(...)) form, those
        margins produced an inf loss and NaN gradients.
        """
        return F.softplus(-f_pred * y).mean()

    train_loss_values = []
    test_loss_values = []
    train_acc_values = []
    test_acc_values = []

    # Per-epoch projections of each neuron's weights onto the two features
    # (shape (m, n_epoch)) and onto each training noise patch (shape
    # (m, n_train, n_epoch)), for the positive and negative branches.
    noise_memorization_p = np.zeros((m, n_train, n_epoch))
    feature_learning1_p = np.zeros((m, n_epoch))
    feature_learning2_p = np.zeros((m, n_epoch))

    noise_memorization_n = np.zeros((m, n_train, n_epoch))
    feature_learning1_n = np.zeros((m, n_epoch))
    feature_learning2_n = np.zeros((m, n_epoch))

    # SGD updates the parameters in place, so these references stay current.
    snapshots = (
        (model.Wp, feature_learning1_p, feature_learning2_p, noise_memorization_p),
        (model.Wn, feature_learning1_n, feature_learning2_n, noise_memorization_n),
    )

    for ep in range(n_epoch):
        # Evaluation before this epoch's update. No autograd graph is needed,
        # which avoids building one for the full test set every epoch.
        model.eval()
        with torch.no_grad():
            for weight, feat1_hist, feat2_hist, noise_hist in snapshots:
                feat1_hist[:, ep] = torch.matmul(weight.T, feature1).numpy()
                feat2_hist[:, ep] = torch.matmul(weight.T, feature2).numpy()
                noise_hist[:, :, ep] = torch.matmul(weight.T, train_x2.T).numpy()

            f_pred_test = model(test_x1, test_x2)
            f_pred_train = model(train_x1, train_x2)

            train_loss_values.append(logistic_loss(f_pred_train, train_y).item())
            test_loss = logistic_loss(f_pred_test, test_y).item()
            test_loss_values.append(test_loss)

            # Sign of f mapped to a {-1, +1} prediction (f == 0 counts as -1).
            pred_binary_train = (f_pred_train > 0).float() * 2 - 1
            pred_binary_test = (f_pred_test > 0).float() * 2 - 1
            train_acc_values.append((pred_binary_train == train_y).float().mean().item())
            test_acc_values.append((pred_binary_test == test_y).float().mean().item())

        model.train()
        epoch_train_loss = 0.0
        for sample_x1, sample_x2, sample_y in data_loader:
            optimizer.zero_grad()
            loss = logistic_loss(model(sample_x1, sample_x2), sample_y)
            loss.backward()
            optimizer.step()
            # Weight by this batch's size, not the dataset size, so that
            # dividing by n_train below yields the per-sample mean loss.
            epoch_train_loss += len(sample_y) * loss.item()
        epoch_train_loss /= n_train

        print(f'[{ep+1}|{n_epoch}] train_loss={epoch_train_loss:0.5e}, test_loss={test_loss:0.5e}')

    model.eval()

    def max_abs_over_neurons(pos_hist, neg_hist):
        """Largest |projection| over all neurons of both the Wp and Wn branches."""
        return np.maximum(
            np.max(np.abs(pos_hist), axis=0),
            np.max(np.abs(neg_hist), axis=0),
        )

    noise_mseris = max_abs_over_neurons(noise_memorization_p, noise_memorization_n)
    feature_mseris1 = max_abs_over_neurons(feature_learning1_p, feature_learning1_n)
    feature_mseris2 = max_abs_over_neurons(feature_learning2_p, feature_learning2_n)

    checkpoint = {
        'model_state_dict': model.state_dict(),
        'loss': train_loss_values,
        'test_loss': test_loss_values,
        'train_acc': train_acc_values,
        'test_acc': test_acc_values,
        'noise': noise_mseris,
        'feature1': feature_mseris1,
        'feature2': feature_mseris2,
    }

    torch.save(checkpoint, f'syn_class_{mu}.pth')